package cn.zifangsky.tree.splaytree;

import cn.zifangsky.stack.LinkStack;
import cn.zifangsky.stack.Stack;

import java.util.List;
import java.util.function.Consumer;

/**
 * 伸展树的定义
 *
 * @author zifangsky
 * @date 2018/11/29
 * @since 1.0.0
 */
public class SplayTree<T extends Comparable<? super T>> {
    /**
     * 伸展树的根节点
     */
    private SplayNode root;

    public SplayTree() {
        root = null;
    }

    /**
     * 插入某个数据
     * @author zifangsky
     * @date 2018/11/28 15:06
     * @since 1.0.0
     * @param data 待插入的数据
     */
    public void insert(T data){
        root = this.insert(root, data);
    }

    /**
     * 批量插入
     * @author zifangsky
     * @date 2018/11/29 14:15
     * @since 1.0.0
     * @param list 待插入的数据
     */
    public void insert(List<T> list){
        if(list != null && list.size() > 0){
            list.forEach(this::insert);
        }
    }

    /**
     * 批量插入
     * @author zifangsky
     * @date 2018/11/29 14:15
     * @since 1.0.0
     * @param values 待插入的数据
     */
    public void insert(T...values){
        if(values != null && values.length > 0){
            for(T data : values){
                this.insert(data);
            }
        }
    }

    /**
     * 删除某个数据
     * @author zifangsky
     * @date 2018/11/28 15:08
     * @since 1.0.0
     * @param data 待删除的数据
     */
    public void remove(T data){
        if(this.isEmpty()){
            throw new RuntimeException("当前伸展树为空，不能执行删除操作");
        }

        root = this.remove(root, data);
    }

    /**
     * 获取当前伸展树的最小元素
     * @author zifangsky
     * @date 2018/11/28 17:39
     * @since 1.0.0
     * @return T
     */
    public T findMin(){
        if(this.isEmpty()){
            throw new RuntimeException("当前伸展树为空，不能执行查找操作");
        }

        root = this.findMin(root);
        return root.element;
    }

    /**
     * 获取当前伸展树的最大元素
     * @author zifangsky
     * @date 2018/11/28 17:39
     * @since 1.0.0
     * @return T
     */
    public T findMax(){
        if(this.isEmpty()){
            throw new RuntimeException("当前伸展树为空，不能执行查找操作");
        }

        root = this.findMax(root);
        return root.element;
    }

    /**
     * 返回某个数据是否在伸展树中
     * @author zifangsky
     * @date 2018/11/28 15:10
     * @since 1.0.0
     * @param data 待检查的数据
     * @return boolean
     */
    public boolean contains(T data){
        return this.contains(root, data);
    }

    /**
     * 清空伸展树
     * @author zifangsky
     * @date 2018/11/28 15:11
     * @since 1.0.0
     */
    public void clear(){
        this.root = null;
    }

    /**
     * 判断伸展树是否为空
     * @author zifangsky
     * @date 2018/11/28 15:13
     * @since 1.0.0
     * @return boolean
     */
    public boolean isEmpty() {
        return this.root == null;
    }

    /**
     * 中序遍历伸展树
     * @author zifangsky
     * @date 2018/11/28 15:18
     * @since 1.0.0
     * @param consumer consumer
     */
    public void forEach(Consumer<T> consumer){
        this.forEach(root, consumer);
    }

    /**
     * 查找某个节点下面最小的节点
     * @param node 指定节点
     */
    private SplayNode findMin(SplayNode node){
        if(node == null){
            return null;
        }

        SplayNode temp = node;
        while (temp.left != null){
            temp = temp.left;
        }

        //将被查询的节点移到根节点
        node = this.splay(node, temp.element);
        return node;
    }

    /**
     * 查找某个节点下面最大的节点
     * @param node 指定节点
     */
    private SplayNode findMax(SplayNode node){
        if(node == null){
            return null;
        }

        SplayNode temp = node;
        while (temp.right != null){
            temp = temp.right;
        }

        //将被查询的节点移到根节点
        node = this.splay(node, temp.element);
        return node;
    }

    /**
     * 返回某个数据是否存在
     * @param node 指定节点
     * @param data 待检查的数据
     */
    private boolean contains(SplayNode node, T data){
        node = this.find(node, data);
        if(node != null){
            root = node;
            return true;
        }else{
            return false;
        }
    }

    /**
     * 查找某个元素，并将这个元素移动到根节点
     * @param node 指定节点
     * @param data 待检查的数据
     */
    private SplayNode find(SplayNode node, T data){
        SplayNode temp = node;
        while (temp != null){
            //比较待检查数据与当前节点谁大谁小
            int compareResult = data.compareTo(temp.element);

            if(compareResult == 0){
                //将被查询的节点移到根节点
                node = this.splay(node, data);
                return node;
            }else if(compareResult < 0){
                temp = temp.left;
            }else{
                temp = temp.right;
            }
        }

        return null;
    }

    /**
     * 插入某个数据
     * @param node 在该节点及其子节点中插入
     * @param data 待插入的数据
     */
    private SplayNode insert(SplayNode node, T data){
        if(node == null){
            return new SplayNode(data);
        }

        SplayNode temp = node;
        //保存一下临时节点
        SplayNode tempNode = temp;
        //新节点被插入到左子树还是右子树
        boolean leftChild = true;

        while (temp != null){
            //比较待检查数据与当前节点谁大谁小
            int compareResult = data.compareTo(temp.element);

            //不插入重复数据
            if(compareResult == 0){
                return node;
            }else if(compareResult < 0){
                //将这个节点保存一下
                tempNode = temp;
                leftChild = true;
                temp = temp.left;
            }else{
                //将这个节点保存一下
                tempNode = temp;
                leftChild = false;
                temp = temp.right;
            }
        }

        if(leftChild){
            tempNode.left = new SplayNode(data);
        }else{
            tempNode.right = new SplayNode(data);
        }

        //将被插入的节点移到根节点
        return this.splay(node, data);
    }

    /**
     * 删除某个数据
     * @param node 在该节点及其子节点中查找并删除数据
     * @param data 待删除的数据
     */
    private SplayNode remove(SplayNode node, T data){
        if(node == null){
            return null;
        }

        //查找被删除元素，并将这个元素移动到根节点
        SplayNode expectedNode = this.find(node, data);
        //查找左孩子节点的最大值，并将这个节点移动到左孩子的位置
        if(expectedNode != null){
            SplayNode leftMaxNode = this.findMax(expectedNode.left);
            //将被删除元素的右子树放到leftMaxNode节点下面
            leftMaxNode.right = expectedNode.right;
            return leftMaxNode;
        }
        return null;
    }

    /**
     * 中序遍历伸展树
     * @param node 遍历该节点及其子节点
     * @param consumer consumer
     */
    private void forEach(SplayNode node, Consumer<T> consumer){
        if(node != null){
            this.forEach(node.left, consumer);
            //处理数据
            consumer.accept(node.element);

            this.forEach(node.right, consumer);
        }
    }

    /**
     * 将被插入/查询的节点移到根节点
     * @param node 被插入/查询的节点
     */
    private SplayNode splay(SplayNode node, T data){
        if(node == null){
            return  null;
        }

        SplayNode temp;
        //定义一个保存节点的栈
        Stack<SplayNode> nodeStack = new LinkStack<>();
        //定义一个走左子树（true）还是右子树（false）的栈
        Stack<Boolean> booleanStack = new LinkStack<>();

        //比较根节点与data的大小
        int compareResult = data.compareTo(node.element);

        while (compareResult != 0){
            //目标节点在左“之字形”
            if(node.left != null && node.left.right != null && data.compareTo(node.left.right.element) == 0){
                //左双旋转
                temp = this.doubleWithLeftChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
                //如果栈不为空，再次出栈
                node = this.assign(node, nodeStack, booleanStack);
            }
            //目标节点在右“之字形”
            else if(node.right != null && node.right.left != null && data.compareTo(node.right.left.element) == 0){
                //右双旋转
                temp = this.doubleWithRightChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
                //如果栈不为空，再次出栈
                node = this.assign(node, nodeStack, booleanStack);
            }
            //目标节点在左“一字形”
            else if(node.left != null && node.left.left != null && data.compareTo(node.left.left.element) == 0){
                temp = this.zigWithLeftChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
                //如果栈不为空，再次出栈
                node = this.assign(node, nodeStack, booleanStack);
            }
            //目标节点在右“一字形”
            else if(node.right != null && node.right.right != null && data.compareTo(node.right.right.element) == 0){
                temp = this.zigWithRightChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
                //如果栈不为空，再次出栈
                node = this.assign(node, nodeStack, booleanStack);
            }
            //目标节点在左孩子节点
            else if(node.left != null && data.compareTo(node.left.element) == 0){
                temp = this.rotateWithLeftChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
            }
            //目标节点在右孩子节点
            else if(node.right != null && data.compareTo(node.right.element) == 0){
                temp = this.rotateWithRightChild(node);
                node = this.assign(temp, nodeStack, booleanStack);
            }
            //目标节点在左子树
            else if(compareResult < 0){
                //入栈
                nodeStack.push(node);
                booleanStack.push(true);
                node = node.left;
            }
            //目标节点在右子树
            else{
                //入栈
                nodeStack.push(node);
                booleanStack.push(false);
                node = node.right;
            }

            compareResult = data.compareTo(node.element);
        }

        return node;
    }

    /**
     * 将调整后的子树挂到它原来的父节点上面
     */
    private SplayNode assign(SplayNode newChild, Stack<SplayNode> nodeStack, Stack<Boolean> booleanStack){
        if(!nodeStack.isEmpty()){
            //出栈取出父节点
            SplayNode node = nodeStack.pop();
            //出栈取出当前子树属于父节点的左孩子还是右孩子
            Boolean leftChild = booleanStack.pop();
            //左孩子
            if(leftChild){
                node.left = newChild;
            }else{
                node.right = newChild;
            }
            return node;
        }else{
            return newChild;
        }
    }

    /**
     * 根据左孩子节点做双旋转
     * @param k3 待旋转子树的根节点
     */
    private SplayNode doubleWithLeftChild(SplayNode k3){
        //先围绕k3的左孩子节点做一次右单旋转
        k3.left = this.rotateWithRightChild(k3.left);
        //再围绕k3做一次左单旋转
        return this.rotateWithLeftChild(k3);
    }

    /**
     * 根据右孩子节点做双旋转
     * @param k1 待旋转子树的根节点
     */
    private SplayNode doubleWithRightChild(SplayNode k1){
        //先围绕k1的右孩子节点做一次左单旋转
        k1.right = this.rotateWithLeftChild(k1.right);
        //再围绕k1做一次右单旋转
        return this.rotateWithRightChild(k1);
    }

    /**
     * 左“一字形”旋转调整
     * @param k3 待旋转子树的根节点
     */
    private SplayNode zigWithLeftChild(SplayNode k3){
        SplayNode k2 = k3.left;
        SplayNode k1 = k2.left;

        k3.left = k2.right;
        k2.right = k3;
        k2.left = k1.right;
        k1.right = k2;

        return k1;
    }

    /**
     * 右“一字形”旋转调整
     * @param k1 待旋转子树的根节点
     */
    private SplayNode zigWithRightChild(SplayNode k1){
        SplayNode k2 = k1.right;
        SplayNode k3 = k2.right;

        k1.right = k2.left;
        k2.left = k1;
        k2.right = k3.left;
        k3.left = k2;

        return k3;
    }

    /**
     * 根据左孩子节点做单旋转
     * @param k2 待旋转子树的根节点
     */
    private SplayNode rotateWithLeftChild(SplayNode k2){
        //k2的左孩子节点
        SplayNode k1 = k2.left;
        //将k1的右孩子挂到k2的左孩子节点
        k2.left = k1.right;
        //将k2挂到k1的右孩子节点
        k1.right = k2;

        return k1;
    }

    /**
     * 根据右孩子节点做单旋转
     * @param k1 待旋转子树的根节点
     */
    private SplayNode rotateWithRightChild(SplayNode k1){
        //k1的右孩子节点
        SplayNode k2 = k1.right;
        //将k2的左孩子挂到k1的右孩子节点
        k1.right = k2.left;
        //将k1挂到k2的左孩子节点
        k2.left = k1;

        return k2;
    }

    /**
     * 定义伸展树的单个节点
     */
    private class SplayNode {
        /**
         * 数据
         */
        T element;
        /**
         * 左孩子节点
         */
        SplayNode left;
        /**
         * 右孩子节点
         */
        SplayNode right;

        SplayNode(T element) {
            this(element, null, null);
        }

        SplayNode(T element, SplayNode left, SplayNode right) {
            this.element = element;
            this.left = left;
            this.right = right;
        }

    }

}
